# Start from an empty global environment: removes every object that ls()
# currently lists (with ls()'s defaults, names starting with "." are kept).
# NOTE(review): this also wipes the workspace of anyone who source()s this
# file from an interactive session, and it does not detach packages or reset
# options -- consider running the script in a fresh R session instead.
rm(list = ls())

library(bsts)
library(stargazer)
library(ggpubr)
library(tidyverse)

# Seed R's random number generator. set.seed() interprets its argument as an
# integer, so pi + exp(1) (about 5.86) is presumably truncated to 5 --
# TODO confirm.
set.seed(pi + exp(1))

# MCMC settings read by bsts_analysis(): each model is fit with
# num_iter + burn_in iterations, and burn_in is passed as the `burn` argument
# when the fitted model is plotted and when it is used to predict.
burn_in <- 200
num_iter <- 1000

# Summarise a numeric vector as c(mean, sd, 2.5% quantile, 97.5% quantile).
# The two quantile entries keep the names that quantile() gives them.
quick_summary <- function(u) {
  interval <- quantile(u, probs = c(0.025, 0.975))
  center <- mean(u)
  spread <- sd(u)
  c(center, spread, interval)
}


# Build a bsts state specification for the series `y`: a seasonal component
# with 7 seasons followed by a semilocal linear trend.
semilocal_trend <- function(y) {
  seasonal_spec <- AddSeasonal(list(), y, nseasons = 7)
  AddSemilocalLinearTrend(seasonal_spec, y)
}


# bsts() wrapped with purrr::quietly(): instead of printing, the wrapper
# returns a list whose `result` element holds the fitted model (captured
# output, messages and warnings are returned alongside it).
quiet_bsts <- purrr::quietly(.f = bsts)

# Factory for AR(p) state-specification constructors.
#
# Args:
#   p: number of autoregressive lags, passed to AddAr() as `lags`.
#
# Returns:
#   A function(y) that builds a bsts state specification for the series `y`
#   with a 7-season seasonal component, an AR(p) component and a static
#   intercept.
ar_trend <- function(p) {
  # Evaluate `p` now. Without this the argument stays an unevaluated promise
  # until the returned function is first called, so constructors created in
  # a loop (or after the caller's variable changes) would pick up the wrong
  # lag order.
  force(p)
  function(y) {
    spec <- AddSeasonal(list(), y, nseasons = 7)
    spec <- AddAr(spec, y, lags = p)
    AddStaticIntercept(spec, y)
  }
}



# Fit a bsts model to one party's pre-treatment series, forecast the
# post-treatment window, export an observed-vs-predicted figure and return
# the estimated effects (observed minus predicted).
#
# Args:
#   trend_constructor: function(y) returning a bsts state specification.
#   df: data frame with columns `party`, `days_after`, `date` and the column
#     named by `outcome`.
#   outcome: name (string) of the outcome column in `df`.
#   outcome_name: label used on the y axes.
#   figure_name: file name handed to ggexport().
#   pid: value of `party` to analyse.
#   treatment_day: value of `days_after` on which the treatment happens; rows
#     before it are used for fitting, rows from it onwards are forecast.
#   horizon_len: number of days after `treatment_day` to keep; the forecast
#     covers horizon_len + 1 days (the treatment day included).
#
# Returns:
#   Data frame with columns days_after, Date, Y, Y_se, Y_lwr, Y_upr restricted
#   to days_after >= treatment_day, where Y is observed minus the mean
#   prediction and Y_lwr / Y_upr are observed minus the upper / lower
#   prediction quantiles.
#
# Reads the file-level objects burn_in, num_iter, quiet_bsts and
# quick_summary.
bsts_analysis <- function(trend_constructor, df, outcome, outcome_name,
  figure_name, pid, treatment_day, horizon_len) {
  temp_df <- df %>%
    filter(party == pid & days_after <= treatment_day + horizon_len) %>%
    arrange(days_after)
  y <- temp_df[[outcome]]

  # Split on treatment_day (not a hard-coded 0) so the split agrees with the
  # filter above and with the forecast horizon below.
  is_pre <- temp_df$days_after < treatment_day
  pre_y <- y[is_pre]
  post_y <- y[!is_pre]
  if (length(post_y) != horizon_len + 1) {
    stop("Expected ", horizon_len + 1, " post-treatment rows for party '",
      pid, "' but found ", length(post_y), call. = FALSE)
  }

  # bsts() and its predict() method take their own `seed` for the underlying
  # sampler. Drawing those seeds from R's RNG makes the MCMC output follow
  # any set.seed() call made by the caller.
  mcmc_seeds <- sample.int(.Machine$integer.max, 2L)

  ss <- trend_constructor(pre_y)
  mod <- quiet_bsts(pre_y, state.specification = ss,
    niter = num_iter + burn_in, seed = mcmc_seeds[1])$result

  # Only the value returned by plot() is wanted, so draw on a null device and
  # close that device even if plotting fails.
  pdf(NULL)
  pre_pred <- tryCatch(plot(mod, burn = burn_in), finally = dev.off())

  pre_pred <- apply(pre_pred, 2, quick_summary)
  post_pred <- predict(mod, burn = burn_in, horizon = horizon_len + 1,
    seed = mcmc_seeds[2])$distribution
  post_pred <- apply(post_pred, 2, quick_summary)

  # Rows of the summaries: 1 = mean, 2 = sd, 3 = 2.5%, 4 = 97.5%.
  pred_df <- data.frame(days_after = temp_df$days_after,
    Date = as.Date(temp_df$date),
    Y = c(pre_pred[1, ], post_pred[1, ]),
    Y_se = c(pre_pred[2, ], post_pred[2, ]),
    Y_lwr = c(pre_pred[3, ], post_pred[3, ]),
    Y_upr = c(pre_pred[4, ], post_pred[4, ]),
    Source = "Prediction")

  obs_df <- pred_df
  obs_df$Y <- c(pre_y, post_y)
  obs_df$Y_se <- NA
  obs_df$Y_lwr <- NA
  obs_df$Y_upr <- NA
  obs_df$Source <- "Observed"

  plot_df <- rbind(pred_df, obs_df)

  # Effect = observed - predicted; the interval bounds swap because the
  # prediction is subtracted. Y_se is carried over from pred_df unchanged.
  effect_df <- pred_df %>% dplyr::select(-Source)
  effect_df$Y <- obs_df$Y - pred_df$Y
  effect_df$Y_lwr <- obs_df$Y - pred_df$Y_upr
  effect_df$Y_upr <- obs_df$Y - pred_df$Y_lwr

  treatment_date <- as.Date(temp_df$date[temp_df$days_after == treatment_day])
  window_start <- treatment_day - 30

  g1 <- plot_df %>% filter(days_after >= window_start) %>%
    ggplot(aes(x = Date, y = Y, color = Source)) +
    geom_vline(xintercept = treatment_date, lty = 3, col = "gray50") +
    geom_line(aes(lty = Source), lwd = 1.0) +
    geom_ribbon(aes(ymin = Y_lwr, ymax = Y_upr, color = NULL, fill = Source),
      alpha = 0.15) +
    scale_fill_manual(values = c("gray50", "gray20")) +
    scale_color_manual(values = c("gray50", "gray20")) +
    labs(y = outcome_name) +
    ggtitle(paste0("Mean: ", round(mean(pre_y), 2),
      "; SD: ", round(sd(pre_y), 2))) +
    theme_classic()

  g2 <- effect_df %>% filter(days_after >= window_start) %>%
    ggplot(aes(x = Date, y = Y, ymin = Y_lwr, ymax = Y_upr)) +
    geom_hline(yintercept = 0, lty = 1, col = "gray50") +
    geom_vline(xintercept = treatment_date, lty = 3, col = "gray50") +
    geom_pointrange(color = "gray20") +
    labs(y = paste0("Obs. - Pred.: ", outcome_name)) +
    theme_classic()

  p <- ggarrange(g1, g2, nrow = 2, ncol = 1, common.legend = TRUE,
    legend = "top")
  ggexport(p, filename = figure_name, height = 7, width = 7)

  return(effect_df %>% filter(days_after >= treatment_day))
}

# Run the bsts forecast analysis for one event and write its results table.
#
# Reads "../results/<event>_panel.csv" and "../data/weights.csv", aggregates
# the panel to party-by-day averages (unweighted and weighted), calls
# bsts_analysis() with semilocal_trend for every party/outcome pair (which
# also exports one figure per pair under "../results/figures/"), and writes a
# LaTeX table to "../results/tables/wforecast_<event>.tex" via stargazer.
#
# Args:
#   event: event code used in the input and output file names.
#   event_name: human-readable event name used in the table title.
#
# Returns:
#   The value of the stargazer() call.
forecast_effect <- function(event, event_name) {
  panel_folder <- "../results/"
  table_folder <- "../results/tables/"
  figure_folder <- "../results/figures/"

  # Weighted mean over the rows where both the value and the weight are
  # observed. Restricting the denominator to those rows keeps a missing
  # outcome from contributing its weight (and thereby pulling the average
  # towards zero).
  weighted_mean_obs <- function(x, w) {
    keep <- !is.na(x) & !is.na(w)
    sum(x[keep] * w[keep]) / sum(w[keep])
  }

  panel <- read_csv(paste0(panel_folder, event, "_panel.csv"))
  weights <- read_csv("../data/weights.csv")
  panel <- panel %>% left_join(weights, by = "caseid")

  party_panel <- panel %>%
    group_by(party, date, days_after) %>%
    summarise(
      num_subjects = n(),
      avg_urls = mean(num_urls),
      wavg_urls = weighted_mean_obs(num_urls, weight),
      avg_portals = mean(num_portals),
      wavg_portals = weighted_mean_obs(num_portals, weight),
      avg_legacy = mean(num_legacy),
      wavg_legacy = weighted_mean_obs(num_legacy, weight),
      avg_online = mean(num_online),
      wavg_online = weighted_mean_obs(num_online, weight),
      avg_polling = mean(num_polling),
      wavg_polling = weighted_mean_obs(num_polling, weight),
      frac_once = mean(num_urls > 0),
      wfrac_once = weighted_mean_obs(num_urls > 0, weight),
      avg_b_align = mean(date_b_align, na.rm = TRUE),
      wavg_b_align = weighted_mean_obs(date_b_align, weight)
    ) %>%
    ungroup()

  # Subjects per party, keyed by party so the lookup below does not depend on
  # the order in which groups come back. The mean is sort of trivial: the
  # count should not vary across days.
  party_sizes <- party_panel %>%
    group_by(party) %>%
    summarise(num_subjects = mean(num_subjects)) %>%
    ungroup()
  party_nums <- setNames(party_sizes$num_subjects, party_sizes$party)

  outcomes <- c("avg_urls", "wavg_urls",
    "avg_portals", "wavg_portals",
    "avg_online", "wavg_online",
    "avg_legacy", "wavg_legacy",
    "avg_polling", "wavg_polling",
    "frac_once", "wfrac_once",
    "avg_b_align", "wavg_b_align")
  outcome_names <- c("No. URLs", "No. URLs",
    "No. Portals", "No. Portals",
    "No. Online", "No. Online",
    "No. Legacy", "No. Legacy",
    "No. Polling", "No. Polling",
    "Fraction Visit", "Fraction Visit",
    "B-Align", "B-Align")
  party_names <- c("DEM", "REP")
  parties <- c("dem", "rep")
  horizon_len <- 5

  tab <- vector("list", length(parties))
  for (p in seq_along(parties)) {
    pre_rows <- party_panel$party == parties[p] & party_panel$days_after < 0
    all_ests <- vector("list", length(outcomes))
    party_means <- rep(NA_real_, length(outcomes))
    for (o in seq_along(outcomes)) {
      # Pre-treatment mean of the outcome, shown in the "[Mean]" row.
      party_means[o] <- mean(party_panel[[outcomes[o]]][pre_rows])

      figure_file <- paste0(figure_folder, "wforecast_", event, "_",
        parties[p], "_", outcomes[o], ".pdf")
      est <- bsts_analysis(semilocal_trend, party_panel, outcomes[o],
        paste0(outcome_names[o], " (", party_names[p], ")"),
        figure_file, parties[p], 0, horizon_len)

      # Flag days whose effect is at least qnorm(0.975) predictive SDs from 0.
      stars <- with(est, abs(Y / Y_se) >= qnorm(0.975))

      # 2 x (horizon_len + 1) character matrix: estimates on row 1, SDs on
      # row 2, one column per day.
      est <- est %>% dplyr::select(Y, Y_se) %>% as.matrix() %>% t() %>%
        signif(2) %>%
        as.character() %>%
        matrix(nrow = 2)

      est[1, which(stars)] <- paste0(est[1, which(stars)], "**")
      est[2, ] <- paste0("(", est[2, ], ")")
      # Flatten column-wise so each day's estimate is followed by its SD.
      all_ests[[o]] <- est %>% matrix(ncol = 1)
    }
    all_ests <- do.call(cbind, all_ests)
    all_ests <- rbind(paste0("[", party_means %>% round(2), "]"), all_ests)

    # "Day" column: the day label on the estimate row, blank on the SD row.
    day_labels <- as.vector(rbind(as.character(0:horizon_len), ""))
    all_ests <- cbind(c("[Mean]", day_labels), all_ests)
    all_ests <- cbind(c(party_names[p],
      paste0("(n = ", party_nums[[parties[p]]], ")"),
      rep("", 1 + horizon_len * 2)), all_ests)

    colnames(all_ests) <- c("Party", "Day", outcome_names)
    tab[[p]] <- all_ests
  }
  tab <- do.call(rbind, tab)

  stargazer(tab, type = "latex", style = "ajps", align = TRUE,
    out = paste0(table_folder, "wforecast_", event, ".tex"),
    title = paste0("Effects of ", event_name))
}

# Run the forecast analysis for every event, in order: each event code is
# paired with the display name used in its table title.
purrr::walk2(
  c("PLACEBO1", "PLACEBO2", "PLACEBO3", "AH", "CL", "FALL", "SEP_DEBATE"),
  c("Placebo 1", "Placebo 2", "Placebo 3", "Access Hollywood",
    "Comey Letter", "Clinton's Fall", "Sep. 26 Debate"),
  forecast_effect
)